// Copyright 2025 NVIDIA CORPORATION
// SPDX-License-Identifier: Apache-2.0

package log

import (
	"fmt"
	"hash/fnv"
	"sync"

	"go.uber.org/zap"
	"go.uber.org/zap/buffer"
	"go.uber.org/zap/zapcore"
)

// Structured-field keys that sessionIDEncoder intercepts, plus the lowest ANSI
// color offset used when colorizing those values.
const (
	// sessionIDField is the field key carrying the scheduling session ID.
	sessionIDField = "sessionID"

	// actionField is the field key carrying the current action name.
	actionField = "action"

	// minColorCode is the smallest foreground-color offset wrapTextWithColor
	// emits, so escape code 30 (black on most terminals) is never used.
	minColorCode = 1
)

// SchedulerLogger is used to wrap other loggers with verbosity level logging similar to glog
type SchedulerLogger interface {
	V(int) *zap.SugaredLogger
	Warningf(string, ...interface{})
	Errorf(string, ...interface{})
	Fatalf(string, ...interface{})
	SetSessionID(string)
	SetAction(string)
	RemoveActionLogger()

	Sync() error
}

type schedulerLogger struct {
	logLevel      int
	sessionID     string
	actionName    string
	baseLogger    *zap.SugaredLogger
	sessionLogger *zap.SugaredLogger
	actionLogger  *zap.SugaredLogger
}

var emptyLogger = zap.NewNop().Sugar()

// InfraLogger should be used for logs generated by scheduler actions/phases
var InfraLogger SchedulerLogger = &schedulerLogger{logLevel: 3, sessionLogger: emptyLogger}

// StatusUpdaterLogger should be used for logs generated by status updater
var StatusUpdaterLogger SchedulerLogger = &schedulerLogger{logLevel: 3, sessionLogger: emptyLogger}

// V gates output glog-style: for lvl above the configured logLevel it returns
// the no-op emptyLogger, otherwise the currently active logger.
func (sl *schedulerLogger) V(lvl int) *zap.SugaredLogger {
	if lvl > sl.logLevel {
		return emptyLogger
	}
	return sl.getLogger()
}

// Warningf logs a printf-style message at warn level on the active logger.
// The extra caller skip makes zap attribute the entry to Warningf's caller.
func (sl *schedulerLogger) Warningf(t string, vars ...interface{}) {
	skipWrapper := zap.AddCallerSkip(1)
	logger := sl.getLogger().Desugar().WithOptions(skipWrapper).Sugar()
	logger.Warnf(t, vars...)
}

// Errorf logs a printf-style message at error level on the active logger.
// The extra caller skip makes zap attribute the entry to Errorf's caller.
func (sl *schedulerLogger) Errorf(t string, vars ...interface{}) {
	desugared := sl.getLogger().Desugar()
	desugared.WithOptions(zap.AddCallerSkip(1)).Sugar().Errorf(t, vars...)
}

// Fatalf logs a printf-style message at fatal level on the active logger;
// zap's fatal level then terminates the process by default. The extra caller
// skip makes zap attribute the entry to Fatalf's caller.
func (sl *schedulerLogger) Fatalf(t string, vars ...interface{}) {
	skipped := sl.getLogger().Desugar().WithOptions(zap.AddCallerSkip(1))
	skipped.Sugar().Fatalf(t, vars...)
}

// Sync flushes buffered entries of whichever logger is currently active.
func (sl *schedulerLogger) Sync() error {
	active := sl.getLogger()
	return active.Sync()
}

// SetSessionID records sessionID and derives a session logger that carries it
// under sessionIDField. It also discards any action logger built for the
// previous session.
//
// The package-level default loggers have no baseLogger until InitLoggers runs.
// In that case the no-op emptyLogger is used as the parent, instead of
// dereferencing a nil *zap.SugaredLogger (which would panic).
func (sl *schedulerLogger) SetSessionID(sessionID string) {
	sl.sessionID = sessionID

	parent := sl.baseLogger
	if parent == nil {
		parent = emptyLogger
	}
	sl.sessionLogger = parent.With(sessionIDField, sessionID)

	sl.RemoveActionLogger()
}

// SetAction records actionName and derives an action logger that carries it
// under actionField. The logger is derived from the session logger, or from
// the base logger when no session has been set.
//
// The session ID is re-attached explicitly because the derived logger's
// encoder is a clone of the parent's. It is only attached when non-empty, so
// an action without a session renders as "[action] msg" rather than with an
// empty "[]" session prefix.
func (sl *schedulerLogger) SetAction(actionName string) {
	sl.actionName = actionName

	// Fall back instead of dereferencing nil: newSchedulerLogger leaves
	// sessionLogger unset until SetSessionID is called.
	parent := sl.sessionLogger
	if parent == nil {
		parent = sl.baseLogger
	}
	if parent == nil {
		parent = emptyLogger
	}

	args := make([]interface{}, 0, 4)
	if sl.sessionID != "" {
		args = append(args, sessionIDField, sl.sessionID)
	}
	args = append(args, actionField, actionName)
	sl.actionLogger = parent.With(args...)
}

// RemoveActionLogger clears the action-scoped logger, so getLogger falls back
// to the session logger (or the base logger). The remembered actionName is
// left as is.
func (sl *schedulerLogger) RemoveActionLogger() {
	var none *zap.SugaredLogger
	sl.actionLogger = none
}

// getLogger returns the most specific logger that is set. The order is
// actionLogger, then sessionLogger, then baseLogger (which may itself be nil).
func (sl *schedulerLogger) getLogger() *zap.SugaredLogger {
	switch {
	case sl.actionLogger != nil:
		return sl.actionLogger
	case sl.sessionLogger != nil:
		return sl.sessionLogger
	default:
		return sl.baseLogger
	}
}

// newSchedulerLogger returns a SchedulerLogger that gates V at logLevel. It
// writes through logger until a session or action is set.
func newSchedulerLogger(logLevel int, logger *zap.SugaredLogger) SchedulerLogger {
	sl := new(schedulerLogger)
	sl.logLevel = logLevel
	sl.baseLogger = logger
	return sl
}

// sessionIDEncoder decorates a zapcore.Encoder so that session and action
// context is rendered as a bracketed prefix on the message, rather than as
// ordinary structured fields.
type sessionIDEncoder struct {
	zapcore.Encoder

	// sessionID and actionName hold the color-wrapped values captured by
	// AddString; empty means "not set".
	sessionID, actionName string
}

// EncodeEntry prefixes the message with session/action context and then
// delegates to the wrapped encoder.
//
// If the entry carries its own sessionIDField field(s), each field's String
// value is prepended as "[value] ". Those fields are removed before
// delegating, and the encoder's stored context is ignored. Otherwise, the
// stored (color-wrapped) sessionID and/or actionName captured by AddString
// are prepended.
func (enc *sessionIDEncoder) EncodeEntry(entry zapcore.Entry, fields []zapcore.Field) (*buffer.Buffer, error) {
	hasSessionField := false
	for _, field := range fields {
		if field.Key == sessionIDField {
			hasSessionField = true
			entry.Message = fmt.Sprintf("[%s] %s", field.String, entry.Message)
		}
	}

	if hasSessionField {
		// Filter into a fresh slice. fields is owned by the caller and must not
		// be rewritten in place, and every sessionID field (not only the last)
		// must be dropped so none is passed through to the wrapped encoder.
		remaining := make([]zapcore.Field, 0, len(fields)-1)
		for _, field := range fields {
			if field.Key != sessionIDField {
				remaining = append(remaining, field)
			}
		}
		fields = remaining
	} else {
		switch {
		case enc.sessionID != "" && enc.actionName != "":
			entry.Message = fmt.Sprintf("[%s] [%s] %s", enc.sessionID, enc.actionName, entry.Message)
		case enc.sessionID != "":
			entry.Message = fmt.Sprintf("[%s] %s", enc.sessionID, entry.Message)
		case enc.actionName != "":
			entry.Message = fmt.Sprintf("[%s] %s", enc.actionName, entry.Message)
		}
	}
	return enc.Encoder.EncodeEntry(entry, fields)
}

// Clone returns an independent copy of the encoder, including the captured
// session/action context. zap clones the encoder whenever Logger.With adds
// fields. Dropping sessionID/actionName here would make any With(...) on a
// session or action logger silently lose its message prefix.
func (enc *sessionIDEncoder) Clone() zapcore.Encoder {
	return &sessionIDEncoder{
		Encoder:    enc.Encoder.Clone(),
		sessionID:  enc.sessionID,
		actionName: enc.actionName,
	}
}

// AddString captures the session and action keys as color-wrapped prefix
// values. Every other key is forwarded unchanged to the wrapped encoder.
func (enc *sessionIDEncoder) AddString(key, value string) {
	switch key {
	case sessionIDField:
		enc.sessionID = wrapTextWithColor(value)
	case actionField:
		enc.actionName = wrapTextWithColor(value)
	default:
		enc.Encoder.AddString(key, value)
	}
}

// wrapTextWithColor surrounds txt with an ANSI foreground-color escape
// ("\033[3<n>m") followed by a reset ("\033[0m"). The digit n is derived
// deterministically from hash(txt) and raised to at least minColorCode.
func wrapTextWithColor(txt string) string {
	code := hash(txt)
	if code < minColorCode {
		code = minColorCode
	}
	return fmt.Sprintf("\033[3%dm%s\033[0m", code, txt)
}

// hash maps s into one of 8 buckets (0-7) using 32-bit FNV-1a.
func hash(s string) uint32 {
	const buckets = 8
	hasher := fnv.New32a()
	_, _ = hasher.Write([]byte(s)) // hash.Hash.Write never returns an error
	return hasher.Sum32() % buckets
}

var (
	// sessionIDEncoderOnce guards the process-wide zap encoder registration,
	// since zap returns an error when a name is registered twice.
	sessionIDEncoderOnce sync.Once

	// sessionIDEncoderErr caches the outcome of that single registration.
	sessionIDEncoderErr error
)

// registerSessionIDEncoder registers the "sessionID" encoding with zap on the
// first call and returns the cached result on every call.
func registerSessionIDEncoder() error {
	sessionIDEncoderOnce.Do(func() {
		sessionIDEncoderErr = zap.RegisterEncoder("sessionID", func(cfg zapcore.EncoderConfig) (zapcore.Encoder, error) {
			return &sessionIDEncoder{
				Encoder: zapcore.NewConsoleEncoder(cfg),
			}, nil
		})
	})
	return sessionIDEncoderErr
}

// InitLoggers builds a console logger on top of zap's production config,
// using the "sessionID" encoding, ISO8601 timestamps, colored levels and no
// stack traces. It replaces InfraLogger and StatusUpdaterLogger with loggers
// gated at logLevel. StatusUpdaterLogger is tagged with the session ID
// "status-updater".
//
// It may be called more than once (e.g. to change logLevel). The custom
// encoder is registered only on the first call, so later calls do not fail
// with zap's "already registered" error.
func InitLoggers(logLevel int) error {
	if err := registerSessionIDEncoder(); err != nil {
		return fmt.Errorf("registering sessionID encoder: %w", err)
	}

	baseLoggerConfig := zap.NewProductionConfig()
	baseLoggerConfig.Encoding = "sessionID"
	baseLoggerConfig.DisableStacktrace = true
	baseLoggerConfig.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
	baseLoggerConfig.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder

	baseLogger, err := baseLoggerConfig.Build()
	if err != nil {
		return fmt.Errorf("building base logger: %w", err)
	}
	logger := baseLogger.Sugar()

	InfraLogger = newSchedulerLogger(logLevel, logger)
	StatusUpdaterLogger = newSchedulerLogger(logLevel, logger)
	StatusUpdaterLogger.SetSessionID("status-updater")
	return nil
}
